{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ba80db7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Work around the Intel OpenMP 'libiomp5 already initialized' abort that can occur when\n",
    "# several packages in the environment each ship their own OpenMP runtime (common on Windows/conda).\n",
    "os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'\n",
    "\n",
    "# Folder with the raw FashionMNIST files. Set FASHION_MNIST_ROOT to point elsewhere instead of\n",
    "# editing this machine-specific default.\n",
    "DatasetRoot = os.environ.get(\"FASHION_MNIST_ROOT\", r\"D:\\deepdotdev\\deepdot-vision\\datasets\\FashionMNIST\\raw\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b1742b3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# Make the project packages (tools, data, models) importable. The path is relative to the\n",
    "# notebook's working directory, so resolve it to an absolute path once, and skip the insert\n",
    "# on re-run so sys.path does not accumulate duplicate entries.\n",
    "ProjectRoot = os.path.abspath(r\"./../../../deepdot_vision/\")\n",
    "if ProjectRoot not in sys.path:\n",
    "    sys.path.insert(0, ProjectRoot)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4904d001",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch.optim as optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7a8347f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tools import add_project_root_2_sys_path, yaml_config_reader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "15a1e7e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from data import get_dataloader_train\n",
    "from models import MnistNet\n",
    "from data import FashionMNIST_CLASS_LIST\n",
    "from data import TBV\n",
    "from data import FashionMNISTTrainDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "eca86f66",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2051 60000 28 28\n",
      "2049 60000\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# Training hyperparameters.\n",
    "SEED = 0\n",
    "BATCH_SIZE = 1010\n",
    "LEARNING_RATE = 0.001\n",
    "MOMENTUM = 0.9\n",
    "\n",
    "# Seed torch's global RNG before building the model and iterating the loader so weight\n",
    "# initialization (and any loader shuffling) is reproducible across kernel restarts.\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Reuse the dataset path defined in the config cell instead of hard-coding it a second time.\n",
    "FashionMnistFolder = DatasetRoot\n",
    "dataloader_train = get_dataloader_train(FashionMnistFolder, image_size=(28, 28), batch_size=BATCH_SIZE)\n",
    "\n",
    "net = MnistNet()\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.SGD(net.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)\n",
    "classes = FashionMNIST_CLASS_LIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "0cffaed4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Log directory for the TBV visualizer; set TB_LOG_DIR to override the machine-specific default.\n",
    "tb_vis = TBV(os.environ.get(\"TB_LOG_DIR\", r\"D:\\deepdotdev\\deepdot-vision\\runs\"))\n",
    "\n",
    "# Grab one batch to log sample images and the model graph.\n",
    "# Use the built-in next(): recent PyTorch DataLoader iterators have no .next() method.\n",
    "dataiter = iter(dataloader_train)\n",
    "images, labels = next(dataiter)\n",
    "\n",
    "tb_vis.add_image_batch(images, 0, \"FashionImages\")\n",
    "tb_vis.add_graph_image(net, images)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "5f5af41a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "current epoch is 0 with loss 2.225588798522949\n",
      "current epoch is 1 with loss 2.141735315322876\n",
      "current epoch is 2 with loss 1.9359245300292969\n",
      "current epoch is 3 with loss 1.5400537252426147\n",
      "current epoch is 4 with loss 1.1913151741027832\n",
      "current epoch is 5 with loss 0.9861752986907959\n",
      "current epoch is 6 with loss 0.8830885291099548\n",
      "current epoch is 7 with loss 0.8284486532211304\n",
      "current epoch is 8 with loss 0.7958691716194153\n",
      "current epoch is 9 with loss 0.7726806998252869\n",
      "current epoch is 10 with loss 0.7525576949119568\n",
      "current epoch is 11 with loss 0.7331835031509399\n",
      "current epoch is 12 with loss 0.7180822491645813\n",
      "current epoch is 13 with loss 0.7038794159889221\n",
      "current epoch is 14 with loss 0.6906247735023499\n",
      "current epoch is 15 with loss 0.6822994947433472\n",
      "current epoch is 16 with loss 0.6715354323387146\n",
      "current epoch is 17 with loss 0.6603990793228149\n",
      "current epoch is 18 with loss 0.6535881161689758\n",
      "current epoch is 19 with loss 0.6465492844581604\n",
      "current epoch is 20 with loss 0.6375810503959656\n",
      "current epoch is 21 with loss 0.6296043992042542\n",
      "current epoch is 22 with loss 0.6254695057868958\n",
      "current epoch is 23 with loss 0.6160691380500793\n",
      "current epoch is 24 with loss 0.6093304753303528\n",
      "current epoch is 25 with loss 0.6037760972976685\n",
      "current epoch is 26 with loss 0.5972257256507874\n",
      "current epoch is 27 with loss 0.5908771753311157\n",
      "current epoch is 28 with loss 0.5841690897941589\n",
      "current epoch is 29 with loss 0.5772621035575867\n",
      "current epoch is 30 with loss 0.5727708339691162\n",
      "current epoch is 31 with loss 0.5690277218818665\n",
      "current epoch is 32 with loss 0.5649452805519104\n",
      "current epoch is 33 with loss 0.5566872954368591\n",
      "current epoch is 34 with loss 0.5523911118507385\n",
      "current epoch is 35 with loss 0.5485659241676331\n",
      "current epoch is 36 with loss 0.5456545352935791\n",
      "current epoch is 37 with loss 0.5393737554550171\n",
      "current epoch is 38 with loss 0.5335840582847595\n",
      "current epoch is 39 with loss 0.5287690758705139\n",
      "current epoch is 40 with loss 0.5239912867546082\n",
      "current epoch is 41 with loss 0.5218101739883423\n",
      "current epoch is 42 with loss 0.5216407179832458\n",
      "current epoch is 43 with loss 0.5142799019813538\n",
      "current epoch is 44 with loss 0.510608434677124\n",
      "current epoch is 45 with loss 0.5057174563407898\n",
      "current epoch is 46 with loss 0.5017272233963013\n",
      "current epoch is 47 with loss 0.50123530626297\n",
      "current epoch is 48 with loss 0.4949646294116974\n",
      "current epoch is 49 with loss 0.49330440163612366\n",
      "current epoch is 50 with loss 0.4904805123806\n",
      "current epoch is 51 with loss 0.48659688234329224\n",
      "current epoch is 52 with loss 0.4841432273387909\n",
      "current epoch is 53 with loss 0.480241596698761\n",
      "current epoch is 54 with loss 0.47658413648605347\n",
      "current epoch is 55 with loss 0.47500139474868774\n",
      "current epoch is 56 with loss 0.4702754020690918\n",
      "current epoch is 57 with loss 0.46744224429130554\n",
      "current epoch is 58 with loss 0.4668739438056946\n",
      "current epoch is 59 with loss 0.46262800693511963\n",
      "current epoch is 60 with loss 0.46182167530059814\n",
      "current epoch is 61 with loss 0.4606652557849884\n",
      "current epoch is 62 with loss 0.45734959840774536\n",
      "current epoch is 63 with loss 0.45493385195732117\n",
      "current epoch is 64 with loss 0.45091158151626587\n",
      "current epoch is 65 with loss 0.4578610956668854\n",
      "current epoch is 66 with loss 0.4471379816532135\n",
      "current epoch is 67 with loss 0.4451156258583069\n",
      "current epoch is 68 with loss 0.4447990357875824\n",
      "current epoch is 69 with loss 0.4454631805419922\n",
      "current epoch is 70 with loss 0.441580593585968\n",
      "current epoch is 71 with loss 0.436247855424881\n",
      "current epoch is 72 with loss 0.4345608949661255\n",
      "current epoch is 73 with loss 0.4320923089981079\n",
      "current epoch is 74 with loss 0.4301767349243164\n",
      "current epoch is 75 with loss 0.4302312135696411\n",
      "current epoch is 76 with loss 0.42634356021881104\n",
      "current epoch is 77 with loss 0.4283297657966614\n",
      "current epoch is 78 with loss 0.42330610752105713\n",
      "current epoch is 79 with loss 0.421547532081604\n",
      "current epoch is 80 with loss 0.4258638620376587\n",
      "current epoch is 81 with loss 0.4205021560192108\n",
      "current epoch is 82 with loss 0.4157060384750366\n",
      "current epoch is 83 with loss 0.41547372937202454\n",
      "current epoch is 84 with loss 0.4146541357040405\n",
      "current epoch is 85 with loss 0.41233590245246887\n",
      "current epoch is 86 with loss 0.4083421528339386\n",
      "current epoch is 87 with loss 0.40566661953926086\n",
      "current epoch is 88 with loss 0.40767475962638855\n",
      "current epoch is 89 with loss 0.4056885838508606\n",
      "current epoch is 90 with loss 0.4032210111618042\n",
      "current epoch is 91 with loss 0.4001581072807312\n",
      "current epoch is 92 with loss 0.3994734585285187\n",
      "current epoch is 93 with loss 0.39932864904403687\n",
      "current epoch is 94 with loss 0.3967350125312805\n",
      "current epoch is 95 with loss 0.3944273293018341\n",
      "current epoch is 96 with loss 0.39221513271331787\n",
      "current epoch is 97 with loss 0.3942088782787323\n",
      "current epoch is 98 with loss 0.3878207206726074\n",
      "current epoch is 99 with loss 0.3894926905632019\n"
     ]
    }
   ],
   "source": [
    "NUM_EPOCHS = 100\n",
    "\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    loss_sum_epoch = 0.\n",
    "    loss_count = 0\n",
    "\n",
    "    net.train()\n",
    "\n",
    "    for inputs, labels in dataloader_train:\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # forward\n",
    "        outputs = net(inputs)\n",
    "        loss = criterion(outputs, labels)\n",
    "\n",
    "        # backward + update weights\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Accumulate a detached copy: summing the live loss tensor would keep every batch's\n",
    "        # autograd graph in memory until the end of the epoch.\n",
    "        loss_sum_epoch += loss.detach()\n",
    "        loss_count += 1\n",
    "\n",
    "    # Mean of the per-batch losses for this epoch.\n",
    "    loss_mean = loss_sum_epoch / loss_count\n",
    "    tb_vis.add_loss_step(loss_mean, epoch, \"LossValues\")\n",
    "    print(\"current epoch is {} with loss {}\".format(epoch, loss_mean))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea95e42b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f15b61a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ff09532",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbab9d7d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "527ab132",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34a12f8d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e683c96e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
